313 save final model #340
Conversation
Added the final epoch number to the file name and changed the implementation to:

```python
# Always save final model weights at the end of training
if self.config.model_save_folder_path is not None:
    self.trainer.save_checkpoint(
        os.path.join(
            self.config.model_save_folder_path,
            f"train-run-final-{self.trainer.current_epoch}.ckpt"
        )
    )
```
Some minor comments to address.
However, can't we get the same results much more easily by setting `enable_checkpointing` always to `True` when creating the `Trainer` and adding a default `ModelCheckpoint` callback? See the Trainer documentation. That way we can benefit from letting Lightning handle all of this.
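A minimal sketch of that suggestion, assuming the Trainer is constructed inside ModelRunner from a config object providing `model_save_folder_path` (all names and values below are placeholders, not the project's actual code):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Placeholder for config.model_save_folder_path in the real code.
model_save_folder_path = "checkpoints/"

# Always enable checkpointing and register a default ModelCheckpoint callback,
# so Lightning itself writes the model weights at the end of every epoch.
checkpoint_callback = ModelCheckpoint(
    dirpath=model_save_folder_path,
    save_on_train_epoch_end=True,
)
trainer = pl.Trainer(
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    max_epochs=5,  # placeholder value
)
```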
Agreed - I've reimplemented this using a `ModelCheckpoint` callback.
Reimplemented using:

```python
# Configure checkpoints.
self.callbacks = [
    ModelCheckpoint(
        dirpath=config.model_save_folder_path,
        save_on_train_epoch_end=True,
    )
]
if config.save_top_k is not None:
    self.callbacks.append(
        ModelCheckpoint(
            dirpath=config.model_save_folder_path,
            monitor="valid_CELoss",
            mode="min",
            save_top_k=config.save_top_k,
        )
    )
```
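For context, a hedged sketch of how such callbacks would then be wired into the Trainer (the actual ModelRunner code is an assumption); note that the second callback only keeps "best" checkpoints if the LightningModule logs a metric named exactly `valid_CELoss` during validation:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Placeholder callbacks standing in for the configuration shown above: one saves
# a snapshot at the end of every training epoch, the other keeps the k models
# with the lowest validation loss (assuming self.log("valid_CELoss", ...) is
# called in the validation step).
callbacks = [
    ModelCheckpoint(dirpath="checkpoints/", save_on_train_epoch_end=True),
    ModelCheckpoint(
        dirpath="checkpoints/", monitor="valid_CELoss", mode="min", save_top_k=5
    ),
]

# Assumption: the callbacks list is passed straight to the Trainer, so Lightning
# handles all checkpoint saving without any manual save_checkpoint calls.
trainer = pl.Trainer(callbacks=callbacks, max_epochs=3, logger=False)
```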
Great, I think that this is a better solution.
The remaining thing to do is to add unit tests that verify that the final model is saved in different situations (different values of steps and epochs). You could also add some tests to check that the periodic checkpoints are properly created.
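A hypothetical sketch of such a test, using a throwaway Lightning module rather than the project's CLI; the module, parameter values, and file checks are illustrative assumptions, not the actual test code:

```python
import os

import pytest
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModule(pl.LightningModule):
    """Throwaway module used only to exercise checkpoint saving."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


@pytest.mark.parametrize("max_epochs", [1, 3])
def test_final_checkpoint_saved(tmp_path, max_epochs):
    """The final model weights should exist on disk after training ends."""
    data = torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=4)
    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_on_train_epoch_end=True)
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[checkpoint],
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(TinyModule(), data)
    assert any(fname.endswith(".ckpt") for fname in os.listdir(tmp_path))
```

Parametrizing over max_epochs (and, analogously, max_steps or val_check_interval) would cover the different stopping scenarios mentioned here.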
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff            @@
##              dev     #340    +/-   ##
=========================================
+ Coverage   89.77%   89.88%   +0.10%
=========================================
  Files          12       12
  Lines         929      929
=========================================
+ Hits          834      835       +1
+ Misses         95       94       -1
```

☔ View full report in Codecov by Sentry.
Sounds good, I added some unit tests that ensure the last model weights are saved in the scenarios where `val_check_interval` is greater than, or not a factor of, the number of training steps. Unfortunately, since the `ModelCheckpoint` that saves a checkpoint at the end of every epoch deletes the previous training epoch's checkpoint when a new epoch checkpoint is saved (it doesn't touch the validation checkpoints), and `CLIRunner.invoke` is blocking, I couldn't think of a practical way to test whether the periodic checkpoints are saved properly.
A few suggestions for the tests.
Sounds good, I factored the save-final-model test into a separate unit test.
Great! The final thing to do is update the changelog.
Sounds great, I added an entry to the changelog.
In order to ensure the final model is always saved, a `trainer.save_checkpoint` call (shown in the first comment above) was added to the end of `ModelRunner.train`.
This implementation was tested using a small training run for the case where `val_check_interval` is not a factor of the total number of training steps, and for the case where `val_check_interval` is greater than the total number of training steps. In both cases the final model checkpoints were saved.